// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.kudu.backup

import org.apache.hadoop.fs.Path
import org.apache.kudu.backup.Backup.TableMetadataPB
import org.apache.yetus.audience.InterfaceAudience
import org.apache.yetus.audience.InterfaceStability

import scala.collection.mutable

/**
 * A directed weighted graph of backups used to pick the optimal series of backups and restores.
 *
 * Each backup is stored under its `fromMs`; following a backup leads to the backups whose
 * `fromMs` equals its `toMs`. Full backups have a `fromMs` of 0, so every path starting at 0
 * is a full backup followed by a chain of incremental backups.
 *
 * @param tableId the ID of the table whose backups this graph holds (used in error messages).
 */
@InterfaceAudience.Private
@InterfaceStability.Unstable
class BackupGraph(val tableId: String) {
  // Index of backup.fromMs -> backups starting at that time, used to chain backups together.
  private val adjacencyList = mutable.Map[Long, mutable.ListBuffer[BackupNode]]()

  // A full backup has FromMs of 0. Declared as Long to match the adjacencyList key type.
  private val FullBackupFromMs = 0L

  /**
   * Add a backup to the graph.
   * @param backup the backup to add.
   */
  def addBackup(backup: BackupNode): Unit = {
    // Add a weighted edge with the backup.
    addEdge(backup)
  }

  // Appends the backup to the list of backups that start at its fromMs.
  private def addEdge(backup: BackupNode): Unit = {
    val adjacentVertices = adjacencyList.getOrElseUpdate(
      backup.metadata.getFromMs,
      mutable.ListBuffer[BackupNode]())
    adjacentVertices += backup
  }

  /**
   * @return true if the graph has a full backup.
   */
  def hasFullBackup: Boolean = fullBackups.nonEmpty

  /**
   * @return all the full backups in the graph, as an immutable copy so callers cannot
   *         mutate the graph's internal state.
   */
  def fullBackups: Seq[BackupNode] = {
    adjacencyList.get(FullBackupFromMs).map(_.toList).getOrElse(Nil)
  }

  /**
   * @return all the backups in the graph.
   */
  def allBackups: Seq[BackupNode] = {
    adjacencyList.values.flatten.toList
  }

  /**
   * @return the full backup with the largest toMs.
   * @throws IllegalStateException if no full backup exists.
   */
  def mostRecentFull: BackupNode = {
    val fulls = fullBackups
    if (fulls.isEmpty) throw new IllegalStateException("No full backup exists")
    fulls.maxBy(_.metadata.getToMs)
  }

  /**
   * Enumerates every path that starts with a full backup and follows chained backups until no
   * further backup can be appended.
   *
   * @return all backup paths in the graph; empty if the graph has no full backup.
   */
  def backupPaths: Seq[BackupPath] = {
    allPaths(FullBackupFromMs, Nil)
      .filterNot(_.isEmpty) // With no full backup, allPaths yields one empty path; drop it.
      .map(BackupPath(_))
  }

  // Depth-first enumeration of every path reachable from `fromMs`, each prefixed by `path`.
  private def allPaths(fromMs: Long, path: List[BackupNode]): List[List[BackupNode]] = {
    // Skip backups already on this path: a backup whose toMs does not advance past its fromMs
    // (or any other cycle) would otherwise recurse forever. Acyclic graphs are unaffected.
    val next = adjacencyList
      .get(fromMs)
      .map(_.filter(node => !path.contains(node)).toList)
      .getOrElse(Nil)
    if (next.isEmpty) {
      List(path)
    } else {
      next.flatMap(node => allPaths(node.metadata.getToMs, path :+ node))
    }
  }

  /**
   * Returns the backup that should be used as the base for the next backup.
   *
   * The logic for picking this backup is as follows:
   *
   *   1. Pick the paths with the most recent full backup.
   *   2. If there are multiple paths, pick the path with the most recent partial backup.
   *   3. If there are multiple paths, pick the path with the lowest weight.
   *
   * This allows concurrent full backups to be taken (or generated by compaction)
   * while also taking incremental backups.
   *
   * While a full backup is running, incremental backups will continue to build
   * off the chain from the previous full. When the new full completes, the
   * next incremental backup will use that full as its "current" chain.
   *
   * @throws IllegalStateException if no full backup exists.
   */
  def backupBase: BackupNode = {
    // Resolve the most recent full first so a graph without a full backup fails with the
    // documented IllegalStateException rather than an "empty.maxBy" error below.
    val recentFull = mostRecentFull

    // 1. Pick the paths with the most recent full backup.
    val recentFulls = backupPaths.filter(_.fullBackup == recentFull)

    // 2. If there are multiple paths, pick the path with the most recent partial backup.
    val maxToMs = recentFulls.map(_.toMs).max
    val recentPaths = recentFulls.filter(_.toMs == maxToMs)

    // 3. If there are multiple paths, pick the path with the lowest weight.
    recentPaths.minBy(_.weight).lastBackup
  }

  /**
   * Returns a sequence of backups that should be used to restore.
   *
   * The logic for picking this path is as follows:
   *
   *   1. Pick the path with the most recent backup.
   *   2. If there are multiple paths, pick the path with the lowest weight.
   *
   * This ensures we always restore the most current state of the data while
   * also picking the most efficient path (likely a result of compaction).
   *
   * @throws IllegalStateException if no valid backup path (i.e. no full backup) exists.
   */
  def restorePath: BackupPath = {
    // Enumerate paths once; each call to backupPaths walks the whole graph.
    val paths = backupPaths
    if (paths.isEmpty) {
      throw new IllegalStateException(s"No valid backups found for table ID: $tableId")
    }

    // 1. Pick the path with the most recent backup.
    val maxToMs = paths.map(_.toMs).max
    val recentPaths = paths.filter(_.toMs == maxToMs)

    // 2. If there are multiple paths, pick the path with the lowest weight.
    recentPaths.minBy(_.weight)
  }

  /**
   * Returns a new BackupGraph that represents the graph including only nodes with a ToMs equal
   * to or less than the specified time. This graph is not modified.
   * @param timeMs the time to filter by.
   * @return a new graph for the same table containing only the (deduplicated) backups whose
   *         toMs is at most `timeMs`.
   */
  def filterByTime(timeMs: Long): BackupGraph = {
    val result = new BackupGraph(tableId)
    allBackups.distinct.filter(_.metadata.getToMs <= timeMs).foreach(result.addBackup)
    result
  }
}

/**
 * A single backup in the backup graph: the location of the backup and its table metadata.
 */
@InterfaceAudience.Private
@InterfaceStability.Unstable
case class BackupNode(path: Path, metadata: TableMetadataPB) {

  /**
   * @return the cost of including this backup in a path: 0 for a full backup
   *         (one whose fromMs is 0) and 1 for a partial (incremental) backup.
   */
  def weight: Int = {
    val isFullBackup = metadata.getFromMs == 0
    if (isFullBackup) {
      0
    } else {
      1
    }
  }
}

/**
 * An ordered chain of backups: a full backup followed by a series of incremental backups.
 */
@InterfaceAudience.Private
@InterfaceStability.Unstable
case class BackupPath(backups: Seq[BackupNode]) {

  /**
   * @return the first backup in the path (the full backup the chain starts from).
   */
  def fullBackup: BackupNode = backups.head

  /**
   * @return the final backup in the path.
   */
  def lastBackup: BackupNode = backups.last

  /**
   * @return the tableName for the entire path, taken from its last backup.
   */
  def tableName: String = lastBackup.metadata.getTableName

  /**
   * @return the toMs for the entire path, taken from its last backup.
   */
  def toMs: Long = lastBackup.metadata.getToMs

  /**
   * @return the weight/cost of the path: the sum of the weights of its backups.
   */
  def weight: Int = backups.foldLeft(0)((total, node) => total + node.weight)

  /**
   * @return the fromMs of each backup joined with " -> ", useful for debugging the path.
   */
  def pathString: String = {
    val fromTimes = backups.map(node => node.metadata.getFromMs)
    fromTimes.mkString(" -> ")
  }
}
